{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d66f42fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: An illegal reflective access operation has occurred\n",
      "WARNING: Illegal reflective access by org.apache.spark.unsafe.Platform (file:/home/alexey/spark/spark-3.0.3-bin-hadoop3.2/jars/spark-unsafe_2.12-3.0.3.jar) to constructor java.nio.DirectByteBuffer(long,int)\n",
      "WARNING: Please consider reporting this to the maintainers of org.apache.spark.unsafe.Platform\n",
      "WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n",
      "WARNING: All illegal access operations will be denied in a future release\n",
      "22/02/21 22:25:30 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
      "Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties\n",
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n"
     ]
    }
   ],
   "source": [
    "import pyspark\n",
    "from pyspark.sql import SparkSession\n",
    "\n",
    "# Create the local Spark session, or reuse one that already exists.\n",
    "spark = (\n",
    "    SparkSession.builder\n",
    "    .master(\"local[*]\")\n",
    "    .appName('test')\n",
    "    .getOrCreate()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "646fc343",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "[Stage 0:>                                                          (0 + 1) / 1]\r",
      "\r",
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "# Load every parquet file matched by the two-level glob under data/pq/green.\n",
    "df_green = spark.read.parquet('data/pq/green/*/*')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "196cccd5",
   "metadata": {},
   "source": [
    "```\n",
    "SELECT \n",
    "    date_trunc('hour', lpep_pickup_datetime) AS hour, \n",
    "    PULocationID AS zone,\n",
    "\n",
    "    SUM(total_amount) AS amount,\n",
    "    COUNT(1) AS number_records\n",
    "FROM\n",
    "    green\n",
    "WHERE\n",
    "    lpep_pickup_datetime >= '2020-01-01 00:00:00'\n",
    "GROUP BY\n",
    "    1, 2\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "74fe52cb",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Keep only the three columns the aggregation uses, then switch to the RDD API.\n",
    "rdd = (\n",
    "    df_green\n",
    "    .select('lpep_pickup_datetime', 'PULocationID', 'total_amount')\n",
    "    .rdd\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1a0bf382",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "fa2b00f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inclusive lower bound for pickup times (same cut-off as the WHERE clause of the SQL above).\n",
    "start = datetime(year=2020, month=1, day=1)\n",
    "\n",
    "def filter_outliers(row):\n",
    "    \"\"\"Return True when the row's pickup time is on or after `start`.\n",
    "\n",
    "    Rows whose pickup time is missing are rejected instead of raising\n",
    "    TypeError, which mirrors how a SQL `>=` comparison drops NULLs.\n",
    "\n",
    "    Args:\n",
    "        row: object exposing a `lpep_pickup_datetime` attribute\n",
    "            (a datetime, or None when the value is missing).\n",
    "\n",
    "    Returns:\n",
    "        bool: True if the row should be kept.\n",
    "    \"\"\"\n",
    "    pickup = row.lpep_pickup_datetime\n",
    "    return pickup is not None and pickup >= start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "69dd326d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the first 10 elements and keep the first one for inspection.\n",
    "rows = rdd.take(10)\n",
    "row = rows[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "cd4b7006",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Row(lpep_pickup_datetime=datetime.datetime(2020, 1, 16, 19, 49, 27), PULocationID=260, total_amount=14.3)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the sampled Row to see its field names and values.\n",
    "row"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "d99eb089",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_for_grouping(row):\n",
    "    \"\"\"Turn a trip row into a `(key, value)` pair for `reduceByKey`.\n",
    "\n",
    "    The key is `(pickup time truncated to the hour, PULocationID)` and the\n",
    "    value is `(total_amount, 1)`, so values can be summed element-wise.\n",
    "    \"\"\"\n",
    "    truncated = row.lpep_pickup_datetime.replace(minute=0, second=0, microsecond=0)\n",
    "    return ((truncated, row.PULocationID), (row.total_amount, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "cb328a44",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_revenue(left_value, right_value):\n",
    "    \"\"\"Combine two `(amount, count)` pairs by adding them element-wise.\n",
    "\n",
    "    Used as the reducer for `reduceByKey`; returns a new `(amount, count)` tuple.\n",
    "    \"\"\"\n",
    "    (amount_a, count_a), (amount_b, count_b) = left_value, right_value\n",
    "    return (amount_a + amount_b, count_a + count_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "2ea260f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import namedtuple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "7dae6064",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flat record type with one field per output column: hour, zone, revenue, count.\n",
    "RevenueRow = namedtuple('RevenueRow', ['hour', 'zone', 'revenue', 'count'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "e0a98ee4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def unwrap(row):\n",
    "    \"\"\"Flatten a `((hour, zone), (revenue, count))` pair into a `RevenueRow`.\"\"\"\n",
    "    key, value = row[0], row[1]\n",
    "    return RevenueRow(hour=key[0], zone=key[1], revenue=value[0], count=value[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "a09200b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql import types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "5c14d15e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explicit schema for the aggregated rows; every field is declared nullable.\n",
    "result_schema = types.StructType([\n",
    "    types.StructField(name, data_type, True)\n",
    "    for name, data_type in (\n",
    "        ('hour', types.TimestampType()),\n",
    "        ('zone', types.IntegerType()),\n",
    "        ('revenue', types.DoubleType()),\n",
    "        ('count', types.IntegerType()),\n",
    "    )\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "56ea72ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# filter -> (key, value) pairs -> per-key sums -> flat records -> DataFrame\n",
    "df_result = (\n",
    "    rdd\n",
    "    .filter(filter_outliers)\n",
    "    .map(prepare_for_grouping)\n",
    "    .reduceByKey(calculate_revenue)\n",
    "    .map(unwrap)\n",
    "    .toDF(result_schema)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "4675bd3f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "# Overwrite any previous output so this cell can be re-run; the default\n",
    "# save mode raises an error when the target path already exists.\n",
    "df_result.write.parquet('tmp/green-revenue', mode='overwrite')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "255b5503",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Columns selected for the duration-prediction example.\n",
    "columns = [\n",
    "    'VendorID',\n",
    "    'lpep_pickup_datetime',\n",
    "    'PULocationID',\n",
    "    'DOLocationID',\n",
    "    'trip_distance',\n",
    "]\n",
    "\n",
    "duration_rdd = df_green.select(columns).rdd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "645c3190",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "921e4ef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample 10 rows; the next cell turns them into a pandas DataFrame.\n",
    "rows = duration_rdd.take(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "f50db3eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a pandas DataFrame from the sampled rows, labelled with the names in `columns`.\n",
    "df = pd.DataFrame(rows, columns=columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "5b8ecc53",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['VendorID',\n",
       " 'lpep_pickup_datetime',\n",
       " 'PULocationID',\n",
       " 'DOLocationID',\n",
       " 'trip_distance']"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the list of selected column names.\n",
    "columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "6766c0f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#model = ...\n",
    "\n",
    "def model_predict(df):\n",
    "    \"\"\"Placeholder model: predict `trip_distance * 5` for every row of `df`.\n",
    "\n",
    "    Args:\n",
    "        df: pandas DataFrame with a `trip_distance` column.\n",
    "\n",
    "    Returns:\n",
    "        The `trip_distance` column multiplied by 5.\n",
    "    \"\"\"\n",
    "    # A real model would be called here instead, e.g. model.predict(df).\n",
    "    return df.trip_distance * 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "7437b848",
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_model_in_batch(rows):\n",
    "    \"\"\"Score one batch of rows with `model_predict` and yield the results.\n",
    "\n",
    "    The rows are loaded into a pandas DataFrame (column names come from the\n",
    "    notebook-level `columns` list), a `predicted_duration` column is added,\n",
    "    and each DataFrame row is yielded as an `itertuples()` tuple, which\n",
    "    carries the DataFrame index as its first field.\n",
    "    \"\"\"\n",
    "    batch = pd.DataFrame(rows, columns=columns)\n",
    "    batch['predicted_duration'] = model_predict(batch)\n",
    "    yield from batch.itertuples()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "580b5845",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "# Score each partition, convert back to a DataFrame and drop the 'Index' column.\n",
    "df_predicts = (\n",
    "    duration_rdd\n",
    "    .mapPartitions(apply_model_in_batch)\n",
    "    .toDF()\n",
    "    .drop('Index')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "6055d543",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "[Stage 48:>                                                         (0 + 1) / 1]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------+\n",
      "|predicted_duration|\n",
      "+------------------+\n",
      "|             12.95|\n",
      "|             31.25|\n",
      "|              14.0|\n",
      "|             12.75|\n",
      "|               0.1|\n",
      "|             11.05|\n",
      "|11.299999999999999|\n",
      "|54.349999999999994|\n",
      "|             15.25|\n",
      "|             91.75|\n",
      "|             12.25|\n",
      "|               3.1|\n",
      "|               7.5|\n",
      "|11.899999999999999|\n",
      "| 78.89999999999999|\n",
      "|              4.45|\n",
      "|              23.2|\n",
      "|              4.85|\n",
      "|              6.65|\n",
      "|              15.1|\n",
      "+------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "# Preview the predicted durations (the saved output shows the top 20 rows).\n",
    "df_predicts.select('predicted_duration').show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e91d243",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
